# models
import shutil
import sys

from tensorboardX import SummaryWriter
from torch import optim

from myloss import *
import torch.nn.modules.loss as Loss
from unet.unet_model import Net
from utils.cmplxBatchNorm import magnitude, normalizeComplexBatch_byMagnitudeOnly
from utils.fftutils import *
from utils.data_vis import ntensorshow, combine_coils_RSOS
import traceback
from utils.dataset_radcine import *


####################################
#
# Set random seeds
#
####################################
# The torch (CPU), CUDA (all devices) and numpy RNGs are all seeded with the
# same value. NOTE(review): Python's `random` module and cuDNN determinism
# flags are not set here, so runs may not be fully reproducible.
seed_num = 888

torch.manual_seed(seed_num)
torch.cuda.manual_seed_all(seed_num)
np.random.seed(seed_num)
# Run configuration object; `Parameters` comes from one of the star imports
# above (presumably utils.dataset_radcine — confirm).
params = Parameters()

####################################
#
# Create Data Generators
#
####################################
# Yields the training and validation data generators plus a params object
# that replaces the one passed in (possibly updated with dataset info — confirm).
training_DG, validation_DG, params = getDatasetGenerators(params)


####################################
#
# Create Model
#
####################################

# Network built with params.n_channels input channels. The second argument (1)
# is presumably the number of output channels/classes — confirm against
# unet.unet_model.Net.
net = Net(params.n_channels, 1)

def multiply_elems(x):
    """Return the product of all elements of the iterable ``x``.

    Used to turn a tensor shape into its element count. An empty iterable
    yields 1, the multiplicative identity.
    """
    acc = 1
    for factor in x:
        acc *= factor
    return acc

# Report model size: the total number of scalar entries across all parameter
# tensors (each tensor contributes the product of its shape).
num_params = 0
for parameters in net.parameters():
    num_params += multiply_elems(parameters.shape)
print('Total number of parameters: {0}'.format(num_params))

# Device placement.
# - Multi-GPU: wrap in DataParallel over every id in params.device_ids except
#   the last. NOTE(review): the reason the last id is excluded is not visible
#   here — confirm.
# - Otherwise: move the module to params.device.
# NOTE(review): train() sends batches to 'cuda:0' regardless of params.device.
if params.multi_GPU:
    net = torch.nn.DataParallel(net, device_ids=params.device_ids[:-1]).cuda()
else:
    net.to(params.device)

def normalize_weight(self, x):
    """Flatten dims 2 and 3 of ``x`` into one and divide by the per-sample size.

    ``x`` is indexed as ``[B, C, H, W]`` and reshaped to ``[B, C, H*W]``.
    The divisor is ``C * H * W``: the channel count times the flattened
    spatial size, not ``H * W`` alone. NOTE(review): confirm this is the
    intended normalisation.

    ``self`` is unused. Although this is a module-level function, it still
    takes a leading placeholder argument that callers must supply.
    """
    batch, chans, height, width = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
    flat = torch.reshape(x, [batch, chans, height * width])
    return flat / (height * width * chans)



####################################
#
# initializations
#
####################################

# Create the checkpoint and TensorBoard output directories. exist_ok=True
# avoids the check-then-create race of os.path.exists() + os.makedirs().
os.makedirs(params.model_save_dir, exist_ok=True)
os.makedirs(params.tensorboard_dir, exist_ok=True)

# Module-level TensorBoard writer used by train() and closed at its end.
writer = SummaryWriter(params.tensorboard_dir)


def _prepare_batch(X, y):
    """Move a raw ``(X, y)`` batch to 'cuda:0' and apply the model's preprocessing.

    X is inverse-FFT'd (unnormalized) and passed through
    ``normalizeComplexBatch_byMagnitudeOnly``. It is then FFT'd back
    (normalized) and fftshift'ed over dims 3 and 4. y only goes through
    ``normalizeComplexBatch_byMagnitudeOnly``. Uses the legacy functional
    ``torch.fft``/``torch.ifft`` API (presumably PyTorch < 1.8 — confirm).

    Returns:
        The preprocessed ``(X, y)`` pair, both on 'cuda:0'.
    """
    X = Variable(torch.FloatTensor(X.float())).to('cuda:0')
    y = Variable(torch.FloatTensor(y.float())).to('cuda:0')

    X = torch.ifft(X, 2, normalized=False)
    X = normalizeComplexBatch_byMagnitudeOnly(X, normalize_over_channel=True)
    X = fftshift2d(torch.fft(X, 2, normalized=True), [3, 4])

    y = normalizeComplexBatch_byMagnitudeOnly(y, normalize_over_channel=True)
    return X, y


def train(net):
    """Train and/or validate ``net`` for epochs ``s_epoch`` .. ``params.epochs``.

    Uses the module-level ``params``, ``training_DG``, ``validation_DG`` and
    ``writer`` (closed on return).

    Checkpoints:
        Saved to ``params.model_save_dir`` as ``MODEL_EPOCH<epoch>.pth`` every
        50 iterations and at the end of each training epoch. On start, the
        highest-numbered checkpoint is loaded (weights, optimizer state, loss
        history and global iteration counter) and that epoch is run again.

    Modes:
        When ``params.Validation_Only`` is set, no training happens. Each
        epoch's checkpoint is loaded and evaluated on ``validation_DG``, and
        epochs without a checkpoint are skipped. Otherwise only training runs.
    """
    ###########################################
    #
    # INITIALIZATIONS
    #
    ############################################
    optimizer = optim.SGD(net.parameters(), lr=params.args.lr, momentum=0.9)

    LOSS = list()  # per-iteration training losses (numpy values), stored in checkpoints
    ssimCriterion = SSIM(window_size=100) if params.network_type == '2D' else SSIM_3D(window_shape=(11, 11, params.num_slices_3D // 2))
    mseCriterion = Loss.MSELoss()

    vld_MSE_LOSS = list()
    vld_SSIM_LOSS = list()
    vld_PSNR_LOSS = list()

    vi = 0  # validation iterations across all epochs
    i = 0   # training iterations across all epochs; stored in checkpoints

    ###########################################
    #
    # LOAD LATEST (or SPECIFIC) MODEL
    #
    ############################################
    ckpt_prefix = 'MODEL_EPOCH'

    def checkpoint_path(epoch):
        """Return the checkpoint path for ``epoch`` (the same name is used for saving and loading)."""
        return params.model_save_dir + '{0}{1}.pth'.format(ckpt_prefix, epoch)

    # Only files named MODEL_EPOCH<integer>.pth count as checkpoints, so stray
    # *.pth files in the directory cannot break the epoch parsing below.
    models = [m for m in os.listdir(params.model_save_dir)
              if m.startswith(ckpt_prefix) and m.endswith('.pth')
              and m[len(ckpt_prefix):-len('.pth')].isdecimal()]
    s_epoch = -1  ## -1: load latest model or start from 1 if there is no saved models
                  ##  0: don't load any model; start from model #1
                  ##  num: load model #num

    def load_model(epoch):
        """Restore net/optimizer state, loss history and iteration count from ``epoch``'s checkpoint.

        Propagates any error from ``torch.load`` (e.g. FileNotFoundError when
        the checkpoint does not exist).
        """
        nonlocal LOSS, i
        print('loading model at epoch ' + str(epoch))
        model = torch.load(checkpoint_path(epoch))
        net.load_state_dict(model['state_dict'])
        optimizer.load_state_dict(model['optimizer'])
        # Write the bookkeeping back into train()'s scope. Without `nonlocal`
        # these assignments only created locals of this helper, so a resumed
        # run restarted its loss history and iteration counter from zero.
        # Checkpoints lacking these keys keep the current values.
        LOSS = list(model.get('loss', LOSS))
        i = model.get('iteration', i)

    print(len(models))

    if s_epoch == -1:
        if not models:
            s_epoch = 1
        else:
            s_epoch = max([int(m[len(ckpt_prefix):-len('.pth')]) for m in models])
            load_model(s_epoch)
    elif s_epoch == 0:
        s_epoch = 1
    else:
        try:
            load_model(s_epoch)
        except FileNotFoundError:
            print('Model {0} does not exist!'.format(s_epoch))

    for epoch in range(s_epoch, params.epochs + 1):
        print('epoch {}/{}...'.format(epoch, params.epochs))

        adjust_learning_rate(optimizer, epoch)

        ###########################################
        #
        # Training
        #
        ############################################
        l = 0    # sum of batch losses this epoch
        itt = 0  # batches processed this epoch
        # Defined before the loop so the end-of-epoch save works even when the
        # loader yields no batches.
        is_best = False
        TAG = 'Training'
        if not params.Validation_Only:
            for X, y, sliceID in training_DG:
                try:
                    X, y = _prepare_batch(X, y)
                    y_pred = net(X.to('cuda:0')).to('cuda:0')
                except Exception:
                    # Best-effort: report and skip batches whose preprocessing
                    # or forward pass fails.
                    traceback.print_exc()
                    continue

                if i % 200 == 0:
                    # Every 200 iterations, save input / reference / prediction figures.
                    ifnames = ['Image_epoch_{0}_iter_{1}_sl_{2}'.format(epoch, i, s) for s in range(0, params.batch_size)]
                    nuft = torch.ifft(X[:,:,params.moving_window_size//2,:,:,:].squeeze(2), 2, normalized=True)
                    ntensorshow((nuft, y, y_pred), (0, 0), (0, 3), ('NUFFT', 'Ref', 'GridNet'), saveFigs = True, figname=ifnames)

                loss = mseCriterion(y_pred, y)

                LOSS.append(loss.cpu().data.numpy())

                # .item() instead of loss.data[0]: indexing a 0-dim tensor is an
                # error from PyTorch 0.5 on (this file already needs >= 0.4 for
                # torch.no_grad).
                loss_value = loss.item()
                l += loss_value

                optimizer.zero_grad()
                loss.backward()
                i += 1
                optimizer.step()
                print('Epoch: {0} - {1:.3f}%'.format(epoch, 100 * (itt * params.batch_size) / len(
                    training_DG.dataset.input_IDs))
                      + ' \tIter: ' + str(i)
                      + '\tLoss: {0:.6f}'.format(loss_value)
                      )
                itt += 1
                if i % 50 == 0:
                    save_checkpoint({'epoch': epoch, 'loss': LOSS, 'arch': 'recoNet_Model1', 'state_dict': net.state_dict(),
                                     'optimizer': optimizer.state_dict(), 'iteration': i,
                                     }, is_best, filename=checkpoint_path(epoch))

                # Forced on (`True or ...`). NOTE(review): runs every iteration and
                # presumably rewrites the whole LOSS history to disk each time.
                if True or params.tbVisualize:
                    writer.add_scalar(TAG + '/' + 'avg_SME', l / itt, epoch)
                    saveArrayToMat(LOSS, 'mse',
                                   'mse_R{0}_Trial{1}'.format(str(params.Rate), params.trialNum), params.tensorboard_dir)

            avg_loss = params.batch_size * l / len(training_DG.dataset.input_IDs)
            print('Total Loss : {0:.6f} \t Avg. Loss {1:.6f}'.format(l, avg_loss))

            save_checkpoint({'epoch': epoch, 'loss': LOSS, 'arch': 'recoNet_Model1', 'state_dict': net.state_dict(),
                             'optimizer': optimizer.state_dict(), 'iteration': i,
                             }, is_best, filename=checkpoint_path(epoch))

        else:
            # Validation-only mode: evaluate the checkpoint saved for this epoch.
            try:
                load_model(epoch)
            except FileNotFoundError:
                print('Model {0} does not exist!'.format(epoch))
                continue

        #####################################
        #
        # Validation
        #
        #####################################

        vitt = 0
        vld_mse = 0
        vld_ssim = 0
        vld_psnr = 0

        # Validation currently runs only in Validation_Only mode; regular
        # training epochs skip it.
        if not params.Validation_Only:
            continue
        TAG = 'Validation'
        with torch.no_grad():
            for X, y, sliceID in validation_DG:
                try:
                    X, y = _prepare_batch(X, y)
                    y_pred = net(X.to('cuda:0')).to('cuda:0')

                    fignames = ['tst_ptn{0}_{1}'.format(s.split('/')[-2][:], s.split('/')[-1][:-5]) for s in sliceID]

                    nuft = torch.ifft(X[:,:,params.moving_window_size//2,:,:,:].squeeze(2), 2, normalized=True)
                    ntensorshow((nuft, y, y_pred), (0, 0), (0, 3), ('NUFFT', 'Ref', 'Network'), saveFigs=True,
                                figname=fignames)

                    # Combine coils / take magnitudes (utils helpers) before computing metrics.
                    y = combine_coils_RSOS(y)
                    nuft = combine_coils_RSOS(nuft)
                    y_pred = magnitude(y_pred)

                    mseloss = mseCriterion(y_pred, y)
                    ssimloss = ssimCriterion(y_pred, y)
                    # torch.log10, not np.log10: the operands are CUDA tensors,
                    # which numpy cannot convert, so np.log10 raised here and every
                    # validation batch ended up in the except below.
                    # NOTE(review): the peak is taken from the prediction, not the reference.
                    psnrloss = 10 * torch.log10(torch.max(y_pred) ** 2 / mseloss)

                    save_results = True
                    if save_results:
                        y = torch.squeeze(y)
                        nuft = torch.squeeze(nuft)
                        y_pred = torch.squeeze(y_pred)

                        for idx, sid in enumerate(sliceID):
                            save_url = params.net_save_dir + 'Results_{0}/'.format(params.arch_name) + sid.split('/')[-2]
                            f_name = sid.split('/')[-1][:-4]

                            saveTensorToMat(y[idx,], 'ref', f_name, save_url+'/ref/')
                            saveTensorToMat(y_pred[idx,], 'net', f_name, save_url+'/net/')
                            saveTensorToMat(nuft[idx,], 'nuft', f_name, save_url+'/nuft/')

                    vld_MSE_LOSS.append(mseloss.cpu().data.numpy())
                    vld_SSIM_LOSS.append(ssimloss.cpu().data.numpy())
                    vld_PSNR_LOSS.append(psnrloss.cpu().data.numpy())

                    vld_mse += mseloss.item()
                    vld_ssim += ssimloss.item()
                    vld_psnr += psnrloss.item()

                    vi += 1
                    vitt += 1

                    print('Epoch: {0} - {1:.3f}%'.format(epoch, 100 * (vitt * params.batch_size) / len(validation_DG.dataset.input_IDs))
                          + ' \tIter: ' + str(vi)
                          + '\tSME: {0:.4f}'.format(mseloss.item())
                          + '\tSSIM: {0:.6f}'.format(ssimloss.item()))
                except Exception:
                    # Best-effort: report and skip batches that fail.
                    traceback.print_exc()
                    continue

            avg_factor = params.batch_size / len(validation_DG.dataset.input_IDs)
            print('Avg. MSE : {0:.6f}'.format(vld_mse * avg_factor)
                  + '\tAvg. SSIM : {0:.6f}'.format(vld_ssim * avg_factor)
                  + '\tAvg. PSNR : {0:.6f}'.format(vld_psnr * avg_factor))

            if True or params.tbVisualize:
                writer.add_scalar(TAG + '/' + 'avg_SME',
                                  params.batch_size * vld_mse / len(validation_DG.dataset.input_IDs), epoch)
                writer.add_scalar(TAG + '/' + 'avg_SSIM',
                                  params.batch_size * vld_ssim / len(validation_DG.dataset.input_IDs), epoch)
                saveArrayToMat(vld_MSE_LOSS, 'vmse', 'vld_mse_R{0}_Trial{1}'.format( params.Rate, params.trialNum),params.tensorboard_dir)
                saveArrayToMat(vld_SSIM_LOSS, 'vssim', 'vld_ssim_R{0}_Trial{1}'.format(params.Rate, params.trialNum), params.tensorboard_dir)

    writer.close()

def ceildiv(a, b):
    """Return ``ceil(a / b)`` using only floor division (exact for ints).

    Works by negating the floor of the negated numerator.
    """
    negated_floor = (-a) // b
    return -negated_floor

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename=None):
    """Save ``state`` with ``torch.save`` and optionally keep a "best" copy.

    Args:
        state: Object to serialize. In this file it is a dict with epoch,
            loss, arch, state_dict, optimizer and iteration entries.
        is_best: When truthy, the written file is also copied to ``best_filename``.
        filename: Destination path of the checkpoint.
        best_filename: Destination of the "best" copy. Defaults to
            ``model_best.pth.tar`` in the same directory as ``filename``, so
            the copy stays next to the checkpoint it came from rather than
            landing in the current working directory.
    """
    torch.save(state, filename)
    print('Model Saved!')
    if is_best:
        if best_filename is None:
            best_filename = os.path.join(os.path.dirname(filename), 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)

def adjust_learning_rate(optimizer, epoch, base_lr=None, step=50, gamma=0.1):
    """Set every param group's LR to ``base_lr * gamma ** (epoch // step)``.

    With the defaults, the initial LR (``params.args.lr``) is decayed by 10x
    every 50 epochs.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts with an
            'lr' key (e.g. a ``torch.optim.Optimizer``).
        epoch: Current epoch number.
        base_lr: Initial learning rate. Defaults to the module-level ``params.args.lr``.
        step: Number of epochs between decays.
        gamma: Multiplicative decay factor applied every ``step`` epochs.
    """
    if base_lr is None:
        base_lr = params.args.lr
    lr = base_lr * (gamma ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Script entry point. It runs at module import time: there is no
# `if __name__ == '__main__'` guard, so importing this module starts training.
try:
    train(net)

except KeyboardInterrupt:
    # On Ctrl-C, save the current weights to MODEL_INTERRUPTED.pth in the
    # current working directory (not params.model_save_dir).
    print('Interrupted')
    torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth')
    try:
        sys.exit(0)
    except SystemExit:
        # sys.exit() always raises SystemExit, so this handler always runs and
        # the process ends via os._exit(0), which bypasses normal interpreter
        # cleanup. NOTE(review): presumably meant to force-terminate lingering
        # data-loader workers — confirm.
        os._exit(0)